Rice Variety Classification¶
This notebook implements a deep learning pipeline to classify rice varieties based on image data. We evaluate three models: a basic ANN, a deep fully connected DNN, and a CNN.
Environment Setup¶
- Import required packages and prepare helper functions for visualization and metrics.
- Create helper functions to help with visualization and analysis
In [ ]:
# Make the repository root importable so sibling modules can be loaded.
import os, sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

# Data handling, numerics, imaging, progress bars
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

# scikit-learn utilities and TensorFlow/Keras modelling stack
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.metrics import roc_curve, auc, classification_report, confusion_matrix, ConfusionMatrixDisplay
from sklearn.preprocessing import label_binarize
from sklearn.utils import resample

# Reproducible results
np.random.seed(42)

# Misc.
plt.style.use('ggplot')
In [2]:
def plot_learning_curve(
    train_loss, val_loss, train_metric, val_metric,
    to_file: str = None, random_baseline: float = 0.1
) -> None:
    """Plot side-by-side loss and accuracy curves for a training run.

    Parameters
    ----------
    train_loss, val_loss : per-epoch loss values for train / validation.
    train_metric, val_metric : per-epoch accuracy values for train / validation.
    to_file : optional path; when given, the figure is also saved there.
    random_baseline : accuracy of a uniform-random classifier, drawn as a
        horizontal reference line. Defaults to 0.1 because this notebook
        classifies 10 rice varieties — the previous hard-coded 0.125 implied
        8 classes and was wrong for this dataset.
    """
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    # Left panel: loss
    ax[0].plot(train_loss, 'r--')
    ax[0].plot(val_loss, 'b--')
    ax[0].set_xlabel("epochs")
    ax[0].set_ylabel("Loss")
    ax[0].legend(['train', 'val'])
    # Right panel: accuracy with a naive-classifier reference line
    ax[1].plot(train_metric, 'r--')
    ax[1].plot(val_metric, 'b--')
    ax[1].set_xlabel('epochs')
    ax[1].set_ylabel('Accuracy')
    ax[1].axhline(y=random_baseline, c='g', alpha=0.5)  # random-guess accuracy
    ax[1].legend(['train', 'val', 'random baseline'])
    fig.tight_layout()
    plt.show()
    if to_file is not None:
        fig.savefig(to_file)
In [3]:
def visualize_32predictions(model, test_df, label_to_index, to_file=None):
    """Show an 8x4 grid of sampled images with ground-truth vs predicted labels.

    Parameters
    ----------
    model : trained Keras model accepting (N, H, W, C) image batches.
    test_df : DataFrame with 'image_array' and 'variety' columns.
    label_to_index : dict mapping variety name -> integer class index.
    to_file : optional path; when given, the figure is also saved there.
    """
    fig, ax = plt.subplots(8, 4, figsize=(20, 20))
    ax = ax.ravel()
    index_to_label = {v: k for k, v in label_to_index.items()}
    sample_df = test_df.sample(32, random_state=42).reset_index(drop=True)
    # Predict the whole sample in ONE batched call instead of 32 separate
    # model.predict() invocations — each call carries significant fixed overhead.
    batch = np.stack(sample_df['image_array'].values)
    pred_indices = np.argmax(model.predict(batch, verbose=0), axis=1)
    for i in range(32):
        img_array = sample_df.loc[i, 'image_array']
        true_label = sample_df.loc[i, 'variety']
        pred_label = index_to_label[int(pred_indices[i])]
        ax[i].imshow(img_array)
        ax[i].axis('off')
        ax[i].set_title(
            f"GT: {true_label}\nPred: {pred_label}",
            fontsize=8
        )
        # Verdict banner above each tile, colored by correctness.
        correct = pred_label == true_label
        ax[i].text(
            0.5, 1.15,
            "CORRECT" if correct else "INCORRECT",
            transform=ax[i].transAxes,
            ha='center',
            va='bottom',
            fontsize=8,
            color='green' if correct else 'red',
            weight='bold'
        )
    plt.tight_layout()
    plt.show()
    if to_file:
        fig.savefig(to_file, bbox_inches='tight')
In [4]:
def plot_roc_auc(y_true, y_pred_probs, class_names):
    """Draw one-vs-rest ROC curves (with per-class AUC) on a single figure.

    y_true holds integer class indices; y_pred_probs is an (N, n_classes)
    array of predicted class probabilities.
    """
    n_classes = len(class_names)
    y_true_bin = label_binarize(y_true, classes=list(range(n_classes)))
    plt.figure(figsize=(10, 8))
    # One ROC curve per class, treating that class as the positive label.
    for class_idx, class_name in enumerate(class_names):
        fpr, tpr, _ = roc_curve(y_true_bin[:, class_idx], y_pred_probs[:, class_idx])
        plt.plot(fpr, tpr, lw=2,
                 label=f'{class_name} (AUC = {auc(fpr, tpr):.2f})')
    plt.plot([0, 1], [0, 1], 'k--', lw=1)  # chance diagonal
    plt.xlim([-0.01, 1.01])
    plt.ylim([-0.01, 1.01])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC-AUC Curve')
    plt.legend(loc='lower right')
    plt.grid(True)
    plt.show()
In [5]:
def print_classification_report(y_true, y_pred_probs, class_names):
    """Print a per-class precision / recall / F1 table for the predictions."""
    predicted_labels = y_pred_probs.argmax(axis=1)
    print(classification_report(y_true, predicted_labels, target_names=class_names, digits=4))
In [6]:
def plot_confusion_matrix(y_true, y_pred_probs, class_labels, normalize=False):
    """Render a confusion matrix; rows are normalized to 1 when `normalize` is True."""
    predicted = np.argmax(y_pred_probs, axis=1)
    norm_mode = 'true' if normalize else None
    matrix = confusion_matrix(y_true, predicted, normalize=norm_mode)
    display = ConfusionMatrixDisplay(confusion_matrix=matrix, display_labels=class_labels)
    display.plot(xticks_rotation=90)
    display.ax_.grid(False)
    plt.show()
Load and Explore Dataset¶
- The metadata is loaded and is used to construct the full image paths. Basic exploratory checks are included to verify image counts per class.
In [7]:
# Load the training metadata
df = pd.read_csv("../data/meta_train.csv")

# Build the full path to every image: ../data/train_images/<label>/<image_id>
df["image_path"] = [
    os.path.join("../data/train_images", label, image_id)
    for label, image_id in zip(df["label"], df["image_id"])
]

# Basic info
print("Total records:", len(df))
df.head()
Total records: 10407
Out[7]:
| image_id | label | variety | age | image_path | |
|---|---|---|---|---|---|
| 0 | 100330.jpg | bacterial_leaf_blight | ADT45 | 45 | ../data/train_images\bacterial_leaf_blight\100... |
| 1 | 100365.jpg | bacterial_leaf_blight | ADT45 | 45 | ../data/train_images\bacterial_leaf_blight\100... |
| 2 | 100382.jpg | bacterial_leaf_blight | ADT45 | 45 | ../data/train_images\bacterial_leaf_blight\100... |
| 3 | 100632.jpg | bacterial_leaf_blight | ADT45 | 45 | ../data/train_images\bacterial_leaf_blight\100... |
| 4 | 101918.jpg | bacterial_leaf_blight | ADT45 | 45 | ../data/train_images\bacterial_leaf_blight\101... |
Data Balancing¶
- Some rice varieties are underrepresented, which could negatively impact the training process of the model. To address this imbalance, upsampling was performed on `variety` to equalize the number of images per variety.
In [8]:
def upsample(df, target_col='variety'):
    """Equalize class sizes by resampling every minority class with replacement.

    Each group in `target_col` smaller than the largest group is upsampled
    (with replacement, fixed seed) to the majority size; the result is
    shuffled and re-indexed.
    """
    # Majority class size is the target for every group.
    target_size = df[target_col].value_counts().max()
    balanced_groups = []
    for _, group in df.groupby(target_col):
        if len(group) < target_size:
            balanced_groups.append(
                resample(
                    group,
                    replace=True,
                    n_samples=target_size,
                    random_state=42
                )
            )
        else:
            balanced_groups.append(group)
    # Shuffle so upsampled duplicates are not clustered together.
    shuffled = pd.concat(balanced_groups).sample(frac=1, random_state=42)
    return shuffled.reset_index(drop=True)

# Upsample the data
df = upsample(df, target_col='variety')
In [9]:
# Confirm the upsampling produced a balanced class distribution.
variety_counts = df['variety'].value_counts()
variety_counts.plot.barh(figsize=(8, 6), title='Number of Images per Variety Label')
plt.xlabel('Image Count')
plt.show()
Train-Validation Split¶
- The dataset is split into 80% training and 20% validation subsets. Stratified sampling ensures that the class distribution is preserved across both sets.
In [10]:
# 80% train, 20% validation, stratified on variety to preserve class balance.
# NOTE(review): upsampling with replacement was applied BEFORE this split, so
# duplicated minority-class rows can appear in both the train and validation
# sets — validation metrics are therefore optimistically biased. Consider
# splitting first and upsampling only the training subset.
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42, shuffle=True, stratify=df['variety'])
print("Training samples:", len(train_df))
print("Validation samples:", len(val_df))
Training samples: 55936 Validation samples: 13984
Image Preprocessing¶
- Each image is resized, center-cropped, normalized, and converted into a NumPy array for model input.
- Grayscale or incompatible images are skipped.
- The `remove_transparency` function ensures that images with alpha channels (e.g., PNGs with transparency) are converted to standard RGB format. This avoids issues during tensor conversion and model input preparation.
Transparency removal
In [11]:
def remove_transparency(image: Image) -> Image:
    """Convert palette / alpha-channel images to plain RGB.

    Modes carrying (or potentially carrying) transparency are routed through
    RGBA before the final RGB conversion; any other mode is returned untouched.
    """
    alpha_modes = ('RGBA', 'RGBa', 'LA', 'La', 'PA', 'P')
    if image.mode not in alpha_modes:
        return image
    rgba = image if image.mode == 'RGBA' else image.convert('RGBA')
    return rgba.convert('RGB')
Crop resizing
In [12]:
def resize_crop(image: Image, width: int, height: int) -> Image:
    """Center-crop `image` to the target aspect ratio, then resize to (width, height)."""
    source_ratio = image.width / image.height
    target_ratio = width / height
    if source_ratio > target_ratio:
        # Source is wider than target: trim equal amounts from left and right.
        crop_width = int(image.height * target_ratio)
        x0 = (image.width - crop_width) // 2
        box = (x0, 0, x0 + crop_width, image.height)
    else:
        # Source is taller than target: trim equal amounts from top and bottom.
        crop_height = int(image.width / target_ratio)
        y0 = (image.height - crop_height) // 2
        box = (0, y0, image.width, y0 + crop_height)
    # High-quality Lanczos resampling for the final downscale.
    return image.crop(box).resize((width, height), Image.Resampling.LANCZOS)
Normalize Image
In [13]:
def normalize_pixels(image: Image) -> Image:
    """Round-trip pixel values through [0, 1] and back to a uint8 image.

    The models below rescale inputs themselves via a `Rescaling(1./255)` layer,
    so this function returns standard 0-255 uint8 pixels. `np.rint` is used
    instead of a bare `astype(np.uint8)` cast because casting truncates toward
    zero, and floating-point round-off in `x / 255.0 * 255` can land just below
    the original integer (e.g. 126.999...), silently shifting pixel values down
    by one. Rounding makes the round-trip an exact identity.
    """
    image_array = np.asarray(image, dtype=np.float64)
    normalized_image_array = image_array / 255.0  # Normalize pixel values to the range [0, 1]
    restored = np.rint(normalized_image_array * 255.0).astype(np.uint8)
    return Image.fromarray(restored)
In [14]:
def image_preprocessing(df, width, height):
    """Open, clean, and crop/resize every image referenced by `df`.

    Grayscale ('L' mode) images are skipped; unreadable files are reported and
    dropped. Returns a frame of the surviving rows with a new 'image_array'
    column holding the processed pixels as NumPy arrays.
    """
    processed_arrays = []
    kept_rows = []
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Preprocessing images"):
        img_path = os.path.normpath(row['image_path'])
        try:
            with Image.open(img_path) as im:
                if im.mode == 'L':
                    # The models expect 3 channels; skip pure-grayscale files.
                    continue
                prepared = normalize_pixels(
                    resize_crop(remove_transparency(im), width, height)
                )
                processed_arrays.append(np.array(prepared))
                kept_rows.append(row)
        except Exception as e:
            # Best-effort: report and continue so one bad file doesn't abort the run.
            print(f"Failed to process {img_path}: {e}")
    df_valid = pd.DataFrame(kept_rows).reset_index(drop=True)
    df_valid['image_array'] = processed_arrays
    return df_valid
In [15]:
# Preprocess both splits to 128x128 RGB arrays (slow: reads every file from disk).
train_df = image_preprocessing(train_df, 128, 128)
val_df = image_preprocessing(val_df, 128, 128)
Preprocessing images: 100%|██████████| 55936/55936 [07:59<00:00, 116.68it/s] Preprocessing images: 100%|██████████| 13984/13984 [02:40<00:00, 87.09it/s]
Prepare Data for Training¶
- Transform image arrays and labels into `x` and `y` format, including one-hot encoding for labels.
In [16]:
# Get unique class labels and map them to integers
class_names = sorted(train_df['variety'].unique())
label_to_index = {label: idx for idx, label in enumerate(class_names)}

def extract_xy(df):
    """Convert a preprocessed frame into (x, y): stacked images and one-hot labels."""
    x = np.stack(df['image_array'].values)
    class_indices = df['variety'].map(label_to_index).to_numpy()
    y = to_categorical(class_indices, num_classes=len(class_names))  # one-hot encode
    return x, y
In [17]:
# Materialize model-ready arrays for both splits.
x_train, y_train = extract_xy(train_df)
x_val, y_val = extract_xy(val_df)
In [18]:
# Sanity-check tensor shapes: expected (samples, 128, 128, 3).
print("x_train shape:", x_train.shape)
print("x_val shape:", x_val.shape)
x_train shape: (55936, 128, 128, 3) x_val shape: (13984, 128, 128, 3)
Model 1: ANN (Artificial Neural Network)¶
A simple ANN with dense layers and a low learning rate is used as a baseline. Early stopping is applied to prevent overfitting.
In [19]:
# Baseline fully connected network.
ann_model = tf.keras.Sequential([
    # Explicit Input layer: passing `input_shape` to Flatten is deprecated in
    # Keras (it raised the UserWarning seen on the original run) and this also
    # matches how the DNN and CNN cells declare their inputs.
    layers.Input(shape=[128, 128, 3]),
    layers.Flatten(),
    layers.Rescaling(1./255),  # scale raw uint8 pixels to [0, 1]
    layers.Dense(128, activation='relu'),
    layers.Dense(128, activation='relu'),
    layers.Dense(128, activation='relu'),
    layers.Dense(128, activation='sigmoid'),
    layers.Dense(10, activation='softmax'),  # 10 rice varieties
])
ann_model.compile(
    # Very small learning rate: slow but stable convergence for the baseline.
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-6),
    loss="categorical_crossentropy",
    metrics=['categorical_accuracy']
)
ann_model.summary()
d:\Python312\Lib\site-packages\keras\src\layers\reshaping\flatten.py:37: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead. super().__init__(**kwargs)
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ flatten (Flatten) │ (None, 49152) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ rescaling (Rescaling) │ (None, 49152) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense (Dense) │ (None, 128) │ 6,291,584 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_1 (Dense) │ (None, 128) │ 16,512 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_2 (Dense) │ (None, 128) │ 16,512 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_3 (Dense) │ (None, 128) │ 16,512 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_4 (Dense) │ (None, 10) │ 1,290 │ └─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 6,342,410 (24.19 MB)
Trainable params: 6,342,410 (24.19 MB)
Non-trainable params: 0 (0.00 B)
In [20]:
# Stop when val_loss hasn't improved for 5 epochs; roll back to the best weights.
early_stop_ann = EarlyStopping(
    monitor="val_loss", patience=5, restore_best_weights=True
)
history_ann = ann_model.fit(
    x_train,
    y_train,
    validation_data=(x_val, y_val),
    epochs=50,
    batch_size=32,
    callbacks=[early_stop_ann],
    verbose=1,
)
Epoch 1/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 243s 136ms/step - categorical_accuracy: 0.1358 - loss: 2.3545 - val_categorical_accuracy: 0.2400 - val_loss: 2.1483 Epoch 2/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 221s 126ms/step - categorical_accuracy: 0.2751 - loss: 2.1033 - val_categorical_accuracy: 0.3950 - val_loss: 1.9771 Epoch 3/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 221s 127ms/step - categorical_accuracy: 0.4356 - loss: 1.9368 - val_categorical_accuracy: 0.5027 - val_loss: 1.8320 Epoch 4/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 220s 126ms/step - categorical_accuracy: 0.5203 - loss: 1.7994 - val_categorical_accuracy: 0.5613 - val_loss: 1.7028 Epoch 5/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 237s 136ms/step - categorical_accuracy: 0.5726 - loss: 1.6703 - val_categorical_accuracy: 0.5973 - val_loss: 1.5900 Epoch 6/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 218s 124ms/step - categorical_accuracy: 0.6011 - loss: 1.5664 - val_categorical_accuracy: 0.6201 - val_loss: 1.4919 Epoch 7/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 214s 122ms/step - categorical_accuracy: 0.6280 - loss: 1.4626 - val_categorical_accuracy: 0.6418 - val_loss: 1.4090 Epoch 8/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 245s 140ms/step - categorical_accuracy: 0.6481 - loss: 1.3872 - val_categorical_accuracy: 0.6532 - val_loss: 1.3380 Epoch 9/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 239s 137ms/step - categorical_accuracy: 0.6675 - loss: 1.3116 - val_categorical_accuracy: 0.6751 - val_loss: 1.2734 Epoch 10/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 217s 124ms/step - categorical_accuracy: 0.6863 - loss: 1.2514 - val_categorical_accuracy: 0.6957 - val_loss: 1.2161 Epoch 11/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 217s 124ms/step - categorical_accuracy: 0.7026 - loss: 1.1982 - val_categorical_accuracy: 0.7125 - val_loss: 1.1639 Epoch 12/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 216s 123ms/step - categorical_accuracy: 0.7118 - loss: 1.1487 - val_categorical_accuracy: 0.7257 - val_loss: 1.1150 Epoch 13/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 213s 122ms/step - categorical_accuracy: 0.7266 - loss: 
1.1017 - val_categorical_accuracy: 0.7379 - val_loss: 1.0698 Epoch 14/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 214s 123ms/step - categorical_accuracy: 0.7436 - loss: 1.0539 - val_categorical_accuracy: 0.7488 - val_loss: 1.0277 Epoch 15/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 215s 123ms/step - categorical_accuracy: 0.7541 - loss: 1.0119 - val_categorical_accuracy: 0.7590 - val_loss: 0.9883 Epoch 16/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 218s 125ms/step - categorical_accuracy: 0.7676 - loss: 0.9659 - val_categorical_accuracy: 0.7672 - val_loss: 0.9498 Epoch 17/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 215s 123ms/step - categorical_accuracy: 0.7753 - loss: 0.9317 - val_categorical_accuracy: 0.7738 - val_loss: 0.9143 Epoch 18/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 215s 123ms/step - categorical_accuracy: 0.7865 - loss: 0.8941 - val_categorical_accuracy: 0.7842 - val_loss: 0.8781 Epoch 19/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 216s 124ms/step - categorical_accuracy: 0.7919 - loss: 0.8553 - val_categorical_accuracy: 0.7920 - val_loss: 0.8469 Epoch 20/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 217s 124ms/step - categorical_accuracy: 0.8009 - loss: 0.8276 - val_categorical_accuracy: 0.7965 - val_loss: 0.8149 Epoch 21/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 215s 123ms/step - categorical_accuracy: 0.8079 - loss: 0.7960 - val_categorical_accuracy: 0.8071 - val_loss: 0.7850 Epoch 22/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 215s 123ms/step - categorical_accuracy: 0.8188 - loss: 0.7655 - val_categorical_accuracy: 0.8139 - val_loss: 0.7589 Epoch 23/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 206s 118ms/step - categorical_accuracy: 0.8219 - loss: 0.7478 - val_categorical_accuracy: 0.8266 - val_loss: 0.7307 Epoch 24/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 206s 118ms/step - categorical_accuracy: 0.8317 - loss: 0.7158 - val_categorical_accuracy: 0.8317 - val_loss: 0.7062 Epoch 25/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 216s 124ms/step - categorical_accuracy: 0.8428 - loss: 0.6785 - val_categorical_accuracy: 0.8377 - val_loss: 0.6818 Epoch 26/50 1748/1748 
━━━━━━━━━━━━━━━━━━━━ 220s 126ms/step - categorical_accuracy: 0.8467 - loss: 0.6623 - val_categorical_accuracy: 0.8395 - val_loss: 0.6586 Epoch 27/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 221s 126ms/step - categorical_accuracy: 0.8515 - loss: 0.6402 - val_categorical_accuracy: 0.8479 - val_loss: 0.6367 Epoch 28/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 224s 128ms/step - categorical_accuracy: 0.8594 - loss: 0.6123 - val_categorical_accuracy: 0.8550 - val_loss: 0.6165 Epoch 29/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 215s 123ms/step - categorical_accuracy: 0.8641 - loss: 0.5981 - val_categorical_accuracy: 0.8602 - val_loss: 0.5952 Epoch 30/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 218s 125ms/step - categorical_accuracy: 0.8715 - loss: 0.5729 - val_categorical_accuracy: 0.8661 - val_loss: 0.5752 Epoch 31/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 217s 124ms/step - categorical_accuracy: 0.8768 - loss: 0.5535 - val_categorical_accuracy: 0.8729 - val_loss: 0.5567 Epoch 32/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 222s 127ms/step - categorical_accuracy: 0.8822 - loss: 0.5326 - val_categorical_accuracy: 0.8751 - val_loss: 0.5396 Epoch 33/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 219s 125ms/step - categorical_accuracy: 0.8856 - loss: 0.5203 - val_categorical_accuracy: 0.8804 - val_loss: 0.5210 Epoch 34/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 218s 124ms/step - categorical_accuracy: 0.8915 - loss: 0.5008 - val_categorical_accuracy: 0.8839 - val_loss: 0.5052 Epoch 35/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 217s 124ms/step - categorical_accuracy: 0.8951 - loss: 0.4863 - val_categorical_accuracy: 0.8858 - val_loss: 0.4899 Epoch 36/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 217s 124ms/step - categorical_accuracy: 0.9007 - loss: 0.4684 - val_categorical_accuracy: 0.8957 - val_loss: 0.4732 Epoch 37/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 218s 124ms/step - categorical_accuracy: 0.9056 - loss: 0.4539 - val_categorical_accuracy: 0.8965 - val_loss: 0.4609 Epoch 38/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 217s 124ms/step - categorical_accuracy: 0.9114 - loss: 0.4350 - 
val_categorical_accuracy: 0.9035 - val_loss: 0.4433 Epoch 39/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 218s 125ms/step - categorical_accuracy: 0.9142 - loss: 0.4182 - val_categorical_accuracy: 0.9067 - val_loss: 0.4306 Epoch 40/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 225s 129ms/step - categorical_accuracy: 0.9183 - loss: 0.4052 - val_categorical_accuracy: 0.9075 - val_loss: 0.4162 Epoch 41/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 215s 123ms/step - categorical_accuracy: 0.9221 - loss: 0.3900 - val_categorical_accuracy: 0.9139 - val_loss: 0.4025 Epoch 42/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 220s 126ms/step - categorical_accuracy: 0.9237 - loss: 0.3768 - val_categorical_accuracy: 0.9160 - val_loss: 0.3910 Epoch 43/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 214s 123ms/step - categorical_accuracy: 0.9266 - loss: 0.3674 - val_categorical_accuracy: 0.9181 - val_loss: 0.3794 Epoch 44/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 217s 124ms/step - categorical_accuracy: 0.9284 - loss: 0.3576 - val_categorical_accuracy: 0.9217 - val_loss: 0.3669 Epoch 45/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 217s 124ms/step - categorical_accuracy: 0.9310 - loss: 0.3424 - val_categorical_accuracy: 0.9219 - val_loss: 0.3562 Epoch 46/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 268s 127ms/step - categorical_accuracy: 0.9342 - loss: 0.3313 - val_categorical_accuracy: 0.9234 - val_loss: 0.3454 Epoch 47/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 218s 125ms/step - categorical_accuracy: 0.9350 - loss: 0.3248 - val_categorical_accuracy: 0.9263 - val_loss: 0.3350 Epoch 48/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 215s 123ms/step - categorical_accuracy: 0.9364 - loss: 0.3125 - val_categorical_accuracy: 0.9306 - val_loss: 0.3254 Epoch 49/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 216s 124ms/step - categorical_accuracy: 0.9400 - loss: 0.3000 - val_categorical_accuracy: 0.9343 - val_loss: 0.3158 Epoch 50/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 218s 125ms/step - categorical_accuracy: 0.9420 - loss: 0.2917 - val_categorical_accuracy: 0.9366 - val_loss: 0.3068
In [21]:
# Learning curves for the ANN baseline.
ann_history = history_ann.history
plot_learning_curve(
    ann_history['loss'],
    ann_history['val_loss'],
    ann_history['categorical_accuracy'],
    ann_history['val_categorical_accuracy'],
)
In [24]:
# Spot-check 32 validation images: ground truth vs ANN prediction.
visualize_32predictions(
    ann_model,
    val_df,
    label_to_index,
)
In [25]:
# Final train/validation metrics for the ANN (best weights restored by early stopping).
ann_train_loss, ann_train_acc = ann_model.evaluate(x_train, y_train, verbose=1)
ann_val_loss, ann_val_acc = ann_model.evaluate(x_val, y_val, verbose=1)
print(f"Train Accuracy: {ann_train_acc:.4f}, Loss: {ann_train_loss:.4f}")
print(f"Val Accuracy: {ann_val_acc:.4f}, Loss: {ann_val_loss:.4f}")
1748/1748 ━━━━━━━━━━━━━━━━━━━━ 39s 22ms/step - categorical_accuracy: 0.9457 - loss: 0.2851 437/437 ━━━━━━━━━━━━━━━━━━━━ 10s 22ms/step - categorical_accuracy: 0.9347 - loss: 0.3082 Train Accuracy: 0.9463, Loss: 0.2841 Val Accuracy: 0.9366, Loss: 0.3068
In [26]:
# Per-class ROC curves for the ANN on the validation set.
ann_y_pred_probs = ann_model.predict(x_val, verbose=1)
ann_y_true = np.argmax(y_val, axis=1)
plot_roc_auc(ann_y_true, ann_y_pred_probs, list(label_to_index.keys()))
437/437 ━━━━━━━━━━━━━━━━━━━━ 9s 21ms/step
In [27]:
print_classification_report(ann_y_true, ann_y_pred_probs, list(label_to_index.keys()))
precision recall f1-score support
ADT45 0.8677 0.5629 0.6829 1398
AndraPonni 0.9060 1.0000 0.9507 1398
AtchayaPonni 0.9021 0.9557 0.9281 1398
IR20 0.9908 1.0000 0.9954 1399
KarnatakaPonni 0.8872 0.8992 0.8931 1399
Onthanel 0.9700 0.9936 0.9816 1398
Ponni 0.8590 0.9621 0.9076 1399
RR 1.0000 1.0000 1.0000 1398
Surya 1.0000 1.0000 1.0000 1398
Zonal 0.9720 0.9921 0.9820 1399
accuracy 0.9366 13984
macro avg 0.9355 0.9366 0.9321 13984
weighted avg 0.9355 0.9366 0.9321 13984
In [28]:
plot_confusion_matrix(ann_y_true, ann_y_pred_probs, list(label_to_index.keys()))
DNN Model¶
In [20]:
# Deeper fully connected network with BatchNorm + Dropout regularization.
dnn_model = tf.keras.Sequential([
    # Input layer: declare shape, rescale to [0, 1], then flatten.
    layers.Input(shape=[128, 128, 3]),
    layers.Rescaling(1./255),
    layers.Flatten(),
])
# Four shrinking hidden blocks: Dense -> BatchNorm -> Dropout(0.3).
for units in (1024, 512, 256, 128):
    dnn_model.add(layers.Dense(units, activation='relu'))
    dnn_model.add(layers.BatchNormalization())
    dnn_model.add(layers.Dropout(0.3))
# Output layer: one probability per variety.
dnn_model.add(layers.Dense(10, activation='softmax'))

# Compile
dnn_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss='categorical_crossentropy',
    metrics=['categorical_accuracy']
)
dnn_model.summary()
Model: "sequential_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ rescaling_1 (Rescaling) │ (None, 128, 128, 3) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ flatten_1 (Flatten) │ (None, 49152) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_4 (Dense) │ (None, 1024) │ 50,332,672 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization_3 │ (None, 1024) │ 4,096 │ │ (BatchNormalization) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_3 (Dropout) │ (None, 1024) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_5 (Dense) │ (None, 512) │ 524,800 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization_4 │ (None, 512) │ 2,048 │ │ (BatchNormalization) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_4 (Dropout) │ (None, 512) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_6 (Dense) │ (None, 256) │ 131,328 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization_5 │ (None, 256) │ 1,024 │ │ (BatchNormalization) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_5 (Dropout) │ (None, 256) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_7 (Dense) │ (None, 128) │ 32,896 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization_6 │ (None, 128) │ 512 │ │ (BatchNormalization) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_6 (Dropout) │ (None, 128) │ 0 │ 
├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_8 (Dense) │ (None, 10) │ 1,290 │ └─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 51,030,666 (194.67 MB)
Trainable params: 51,026,826 (194.65 MB)
Non-trainable params: 3,840 (15.00 KB)
In [21]:
# Same early-stopping policy as the ANN: patience 5 on val_loss, keep best weights.
early_stop_dnn = EarlyStopping(
    monitor="val_loss", patience=5, restore_best_weights=True
)
history_dnn = dnn_model.fit(
    x_train,
    y_train,
    validation_data=(x_val, y_val),
    epochs=50,
    batch_size=32,
    callbacks=[early_stop_dnn],
    verbose=1,
)
Epoch 1/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 1192s 677ms/step - categorical_accuracy: 0.5908 - loss: 1.3072 - val_categorical_accuracy: 0.6394 - val_loss: 1.4978 Epoch 2/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 1219s 697ms/step - categorical_accuracy: 0.8658 - loss: 0.4142 - val_categorical_accuracy: 0.7114 - val_loss: 0.8204 Epoch 3/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 1284s 735ms/step - categorical_accuracy: 0.8918 - loss: 0.3292 - val_categorical_accuracy: 0.7931 - val_loss: 0.6185 Epoch 4/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 1296s 741ms/step - categorical_accuracy: 0.9079 - loss: 0.2750 - val_categorical_accuracy: 0.8285 - val_loss: 0.4571 Epoch 5/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 1248s 714ms/step - categorical_accuracy: 0.9197 - loss: 0.2390 - val_categorical_accuracy: 0.8302 - val_loss: 0.6066 Epoch 6/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 1263s 723ms/step - categorical_accuracy: 0.9377 - loss: 0.1831 - val_categorical_accuracy: 0.9146 - val_loss: 0.2217 Epoch 7/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 1203s 688ms/step - categorical_accuracy: 0.9239 - loss: 0.2262 - val_categorical_accuracy: 0.9255 - val_loss: 0.2046 Epoch 8/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 1161s 664ms/step - categorical_accuracy: 0.9376 - loss: 0.1851 - val_categorical_accuracy: 0.7534 - val_loss: 1.0889 Epoch 9/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 1114s 637ms/step - categorical_accuracy: 0.9361 - loss: 0.1853 - val_categorical_accuracy: 0.7516 - val_loss: 0.7670 Epoch 10/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 1106s 632ms/step - categorical_accuracy: 0.9344 - loss: 0.1915 - val_categorical_accuracy: 0.9603 - val_loss: 0.1140 Epoch 11/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 1044s 597ms/step - categorical_accuracy: 0.9344 - loss: 0.1905 - val_categorical_accuracy: 0.8628 - val_loss: 0.3820 Epoch 12/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 1029s 588ms/step - categorical_accuracy: 0.9277 - loss: 0.2163 - val_categorical_accuracy: 0.9497 - val_loss: 0.1699 Epoch 13/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 1028s 588ms/step - categorical_accuracy: 
0.9390 - loss: 0.1767 - val_categorical_accuracy: 0.9748 - val_loss: 0.0682 Epoch 14/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 1025s 587ms/step - categorical_accuracy: 0.9486 - loss: 0.1527 - val_categorical_accuracy: 0.7765 - val_loss: 0.9778 Epoch 15/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 1025s 586ms/step - categorical_accuracy: 0.9470 - loss: 0.1577 - val_categorical_accuracy: 0.9276 - val_loss: 0.2073 Epoch 16/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 1043s 597ms/step - categorical_accuracy: 0.9352 - loss: 0.1883 - val_categorical_accuracy: 0.9559 - val_loss: 0.1368 Epoch 17/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 1026s 587ms/step - categorical_accuracy: 0.9461 - loss: 0.1596 - val_categorical_accuracy: 0.9667 - val_loss: 0.1156 Epoch 18/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 1025s 586ms/step - categorical_accuracy: 0.9466 - loss: 0.1550 - val_categorical_accuracy: 0.9610 - val_loss: 0.1312
In [22]:
# Learning curves for the DNN.
dnn_history = history_dnn.history
plot_learning_curve(
    dnn_history['loss'],
    dnn_history['val_loss'],
    dnn_history['categorical_accuracy'],
    dnn_history['val_categorical_accuracy'],
)
In [23]:
# Spot-check 32 validation images: ground truth vs DNN prediction.
visualize_32predictions(
    dnn_model,
    val_df,
    label_to_index,
)
In [24]:
# Final train/validation metrics for the DNN (best weights restored by early stopping).
dnn_train_loss, dnn_train_acc = dnn_model.evaluate(x_train, y_train, verbose=1)
dnn_val_loss, dnn_val_acc = dnn_model.evaluate(x_val, y_val, verbose=1)
print(f"Train Accuracy: {dnn_train_acc:.4f}, Loss: {dnn_train_loss:.4f}")
print(f"Val Accuracy: {dnn_val_acc:.4f}, Loss: {dnn_val_loss:.4f}")
1748/1748 ━━━━━━━━━━━━━━━━━━━━ 82s 47ms/step - categorical_accuracy: 0.9856 - loss: 0.0381 437/437 ━━━━━━━━━━━━━━━━━━━━ 19s 44ms/step - categorical_accuracy: 0.9743 - loss: 0.0688 Train Accuracy: 0.9859, Loss: 0.0375 Val Accuracy: 0.9748, Loss: 0.0682
In [25]:
# Per-class ROC curves for the DNN on the validation set.
dnn_y_pred_probs = dnn_model.predict(x_val, verbose=1)
dnn_y_true = np.argmax(y_val, axis=1)
plot_roc_auc(dnn_y_true, dnn_y_pred_probs, list(label_to_index.keys()))
437/437 ━━━━━━━━━━━━━━━━━━━━ 18s 42ms/step
In [26]:
print_classification_report(dnn_y_true, dnn_y_pred_probs, list(label_to_index.keys()))
precision recall f1-score support
ADT45 0.8687 0.8805 0.8746 1398
AndraPonni 0.9879 0.9921 0.9900 1398
AtchayaPonni 0.9873 0.9986 0.9929 1398
IR20 1.0000 1.0000 1.0000 1399
KarnatakaPonni 0.9660 0.9743 0.9701 1399
Onthanel 0.9804 1.0000 0.9901 1398
Ponni 0.9642 1.0000 0.9818 1399
RR 1.0000 1.0000 1.0000 1398
Surya 1.0000 1.0000 1.0000 1398
Zonal 0.9968 0.9021 0.9471 1399
accuracy 0.9748 13984
macro avg 0.9751 0.9748 0.9747 13984
weighted avg 0.9751 0.9748 0.9747 13984
In [27]:
plot_confusion_matrix(dnn_y_true, dnn_y_pred_probs, list(label_to_index.keys()))
CNN Model¶
In [28]:
# Convolutional network: four conv/pool stages followed by a dense head.
cnn_model = tf.keras.Sequential([
    tf.keras.Input(shape=(128, 128, 3)),
    tf.keras.layers.Rescaling(1./255),
])
# Feature extractor: widening 3x3 conv blocks, each followed by 2x2 max-pooling.
for n_filters in (32, 32, 64, 128):
    cnn_model.add(tf.keras.layers.Conv2D(n_filters, 3, padding='same', activation='relu'))
    cnn_model.add(tf.keras.layers.MaxPooling2D())
# Classifier head.
cnn_model.add(tf.keras.layers.Flatten())
cnn_model.add(tf.keras.layers.Dropout(0.2))
cnn_model.add(tf.keras.layers.Dense(128, activation='relu'))
cnn_model.add(tf.keras.layers.Dense(10, activation='softmax'))

cnn_model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
cnn_model.summary()
Model: "sequential_2"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ rescaling_2 (Rescaling) │ (None, 128, 128, 3) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d (Conv2D) │ (None, 128, 128, 32) │ 896 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d (MaxPooling2D) │ (None, 64, 64, 32) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_1 (Conv2D) │ (None, 64, 64, 32) │ 9,248 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_1 (MaxPooling2D) │ (None, 32, 32, 32) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_2 (Conv2D) │ (None, 32, 32, 64) │ 18,496 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_2 (MaxPooling2D) │ (None, 16, 16, 64) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_3 (Conv2D) │ (None, 16, 16, 128) │ 73,856 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_3 (MaxPooling2D) │ (None, 8, 8, 128) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ flatten_2 (Flatten) │ (None, 8192) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_7 (Dropout) │ (None, 8192) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_9 (Dense) │ (None, 128) │ 1,048,704 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_10 (Dense) │ (None, 10) │ 1,290 │ └─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 1,152,490 (4.40 MB)
Trainable params: 1,152,490 (4.40 MB)
Non-trainable params: 0 (0.00 B)
In [32]:
tf.keras.backend.clear_session()
WARNING:tensorflow:From d:\Python312\Lib\site-packages\keras\src\backend\common\global_state.py:82: The name tf.reset_default_graph is deprecated. Please use tf.compat.v1.reset_default_graph instead.
In [29]:
# Same early-stopping policy as the other models: patience 5 on val_loss.
early_stop_cnn = EarlyStopping(
    monitor="val_loss", patience=5, restore_best_weights=True
)
history_cnn = cnn_model.fit(
    x_train,
    y_train,
    validation_data=(x_val, y_val),
    epochs=50,
    batch_size=32,
    callbacks=[early_stop_cnn],
    verbose=1,
)
Epoch 1/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 379s 215ms/step - accuracy: 0.7509 - loss: 0.7378 - val_accuracy: 0.9752 - val_loss: 0.0753 Epoch 2/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 369s 211ms/step - accuracy: 0.9814 - loss: 0.0577 - val_accuracy: 0.9921 - val_loss: 0.0245 Epoch 3/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 360s 206ms/step - accuracy: 0.9895 - loss: 0.0327 - val_accuracy: 0.9923 - val_loss: 0.0274 Epoch 4/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 364s 208ms/step - accuracy: 0.9913 - loss: 0.0260 - val_accuracy: 0.9959 - val_loss: 0.0132 Epoch 5/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 363s 208ms/step - accuracy: 0.9941 - loss: 0.0180 - val_accuracy: 0.9935 - val_loss: 0.0243 Epoch 6/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 363s 208ms/step - accuracy: 0.9936 - loss: 0.0199 - val_accuracy: 0.9965 - val_loss: 0.0139 Epoch 7/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 365s 209ms/step - accuracy: 0.9943 - loss: 0.0191 - val_accuracy: 0.9962 - val_loss: 0.0150 Epoch 8/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 364s 208ms/step - accuracy: 0.9961 - loss: 0.0133 - val_accuracy: 0.9944 - val_loss: 0.0236 Epoch 9/50 1748/1748 ━━━━━━━━━━━━━━━━━━━━ 377s 216ms/step - accuracy: 0.9954 - loss: 0.0148 - val_accuracy: 0.9938 - val_loss: 0.0223
In [30]:
# Learning curves for the CNN (metric keys are 'accuracy' for this model).
cnn_history = history_cnn.history
plot_learning_curve(
    cnn_history['loss'],
    cnn_history['val_loss'],
    cnn_history['accuracy'],
    cnn_history['val_accuracy'],
)
In [31]:
# Spot-check 32 validation images: ground truth vs CNN prediction.
visualize_32predictions(
    cnn_model,
    val_df,
    label_to_index,
)
In [32]:
# Final train/validation metrics for the CNN (best weights restored by early stopping).
cnn_train_loss, cnn_train_acc = cnn_model.evaluate(x_train, y_train, verbose=1)
cnn_val_loss, cnn_val_acc = cnn_model.evaluate(x_val, y_val, verbose=1)
print(f"Train Accuracy: {cnn_train_acc:.4f}, Loss: {cnn_train_loss:.4f}")
print(f"Val Accuracy: {cnn_val_acc:.4f}, Loss: {cnn_val_loss:.4f}")
1748/1748 ━━━━━━━━━━━━━━━━━━━━ 94s 54ms/step - accuracy: 0.9986 - loss: 0.0038 437/437 ━━━━━━━━━━━━━━━━━━━━ 25s 57ms/step - accuracy: 0.9961 - loss: 0.0115 Train Accuracy: 0.9987, Loss: 0.0040 Val Accuracy: 0.9959, Loss: 0.0132
In [33]:
# Per-class ROC curves for the CNN on the validation set.
cnn_y_pred_probs = cnn_model.predict(x_val, verbose=1)
cnn_y_true = np.argmax(y_val, axis=1)
plot_roc_auc(cnn_y_true, cnn_y_pred_probs, list(label_to_index.keys()))
437/437 ━━━━━━━━━━━━━━━━━━━━ 23s 53ms/step
In [34]:
print_classification_report(cnn_y_true, cnn_y_pred_probs, list(label_to_index.keys()))
precision recall f1-score support
ADT45 0.9912 0.9678 0.9794 1398
AndraPonni 0.9986 1.0000 0.9993 1398
AtchayaPonni 0.9957 1.0000 0.9979 1398
IR20 0.9971 1.0000 0.9986 1399
KarnatakaPonni 0.9964 0.9943 0.9953 1399
Onthanel 0.9950 0.9979 0.9964 1398
Ponni 0.9873 0.9993 0.9933 1399
RR 0.9993 1.0000 0.9996 1398
Surya 1.0000 1.0000 1.0000 1398
Zonal 0.9986 1.0000 0.9993 1399
accuracy 0.9959 13984
macro avg 0.9959 0.9959 0.9959 13984
weighted avg 0.9959 0.9959 0.9959 13984
In [35]:
plot_confusion_matrix(cnn_y_true, cnn_y_pred_probs, list(label_to_index.keys()))
In [ ]:
cnn_model.save("../models/cnn_variety_prediction_model.keras", overwrite = True)